{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "post_training_quant.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6Y8E0lw5eYWm"
      },
      "source": [
        "# Post-training dynamic range quantization\n",
        "\n",
        "**Learning Objectives**\n",
        "  1. We will learn how to train a TensorFlow model.\n",
        "  2. We will learn how to load the model into an interpreter.\n",
        "  3. We will learn how to evaluate the models.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BTC1rDAuei_1"
      },
      "source": [
        "## Introduction\n",
        "\n",
        "[TensorFlow Lite](https://www.tensorflow.org/lite/) now supports\n",
        "converting weights to 8 bit precision as part of model conversion from\n",
        "tensorflow graphdefs to TensorFlow Lite's flat buffer format. Dynamic range quantization achieves a 4x reduction in the model size. In addition, TFLite supports on the fly quantization and dequantization of activations to allow for:\n",
        "\n",
        "1.  Using quantized kernels for faster implementation when available.\n",
        "2.  Mixing of floating-point kernels with quantized kernels for different parts\n",
        "    of the graph.\n",
        "\n",
        "The activations are always stored in floating point. For ops that\n",
        "support quantized kernels, the activations are quantized to 8 bits of precision\n",
        "dynamically prior to processing and are de-quantized to float precision after\n",
        "processing. Depending on the model being converted, this can give a speedup over\n",
        "pure floating point computation.\n",
        "\n",
        "In contrast to\n",
        "[quantization aware training](https://github.com/tensorflow/tensorflow/tree/r1.14/tensorflow/contrib/quantize)\n",
        ", the weights are quantized post training and the activations are quantized dynamically \n",
        "at inference in this method.\n",
        "Therefore, the model weights are not retrained to compensate for quantization\n",
        "induced errors. It is important to check the accuracy of the quantized model to\n",
        "ensure that the degradation is acceptable.\n",
        "\n",
        "This tutorial trains an MNIST model from scratch, checks its accuracy in\n",
        "TensorFlow, and then converts the model into a Tensorflow Lite flatbuffer\n",
        "with dynamic range quantization. Finally, it checks the\n",
        "accuracy of the converted model and compare it to the original float model.\n",
        "\n",
        "Each learning objective will correspond to a __#TODO__ in the [student lab notebook](https://github.com/GoogleCloudPlatform/training-data-analyst/blob/master/courses/machine_learning/deepdive2/production_ml/labs/post_training_quant.ipynb) -- try to complete that notebook first before reviewing this solution notebook.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2XsEP17Zelz9"
      },
      "source": [
        "## Build an MNIST model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dDqqUIZjZjac"
      },
      "source": [
        "### Setup"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gyqAw1M9lyab"
      },
      "source": [
        "# Importing necessary modules\n",
        "import logging\n",
        "logging.getLogger(\"tensorflow\").setLevel(logging.DEBUG)\n",
        "\n",
        "import tensorflow as tf\n",
        "from tensorflow import keras\n",
        "import numpy as np\n",
        "import pathlib"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eQ6Q0qqKZogR"
      },
      "source": [
        "### Train a TensorFlow model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hWSAjQWagIHl"
      },
      "source": [
        "# Load MNIST dataset\n",
        "mnist = keras.datasets.mnist\n",
        "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n",
        "\n",
        "# Normalize the input image so that each pixel value is between 0 to 1.\n",
        "train_images = train_images / 255.0\n",
        "test_images = test_images / 255.0\n",
        "\n",
        "# Define the model architecture\n",
        "model = keras.Sequential([\n",
        "  keras.layers.InputLayer(input_shape=(28, 28)),\n",
        "  keras.layers.Reshape(target_shape=(28, 28, 1)),\n",
        "  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation=tf.nn.relu),\n",
        "  keras.layers.MaxPooling2D(pool_size=(2, 2)),\n",
        "  keras.layers.Flatten(),\n",
        "  keras.layers.Dense(10)\n",
        "])\n",
        "\n",
        "# Train the digit classification model\n",
        "model.compile(optimizer='adam',\n",
        "              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "              metrics=['accuracy'])\n",
        "# TODO 1 - Here is your code.\n",
        "model.fit(\n",
        "  train_images,\n",
        "  train_labels,\n",
        "  epochs=1,\n",
        "  validation_data=(test_images, test_labels)\n",
        ")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5NMaNZQCkW9X"
      },
      "source": [
        "For the example, since you trained the model for just a single epoch, so it only trains to ~96% accuracy.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xl8_fzVAZwOh"
      },
      "source": [
        "### Convert to a TensorFlow Lite model\n",
        "\n",
        "Using the Python [TFLiteConverter](https://www.tensorflow.org/lite/convert/python_api), you can now convert the trained model into a TensorFlow Lite model.\n",
        "\n",
        "Now load the model using the `TFLiteConverter`:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_i8B2nDZmAgQ"
      },
      "source": [
        "converter = tf.lite.TFLiteConverter.from_keras_model(model)\n",
        "tflite_model = converter.convert()"
      ],
      "execution_count": null,
      "outputs": [
        {
         "name": "stdout",
         "output_type": "stream",
         "text": [
          "INFO:tensorflow:Assets written to: /tmp/tmpkn2_3ey6/assets\n"
         ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "F2o2ZfF0aiCx"
      },
      "source": [
        "Write it out to a tflite file:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vptWZq2xnclo"
      },
      "source": [
        "tflite_models_dir = pathlib.Path(\"/tmp/mnist_tflite_models/\")\n",
        "tflite_models_dir.mkdir(exist_ok=True, parents=True)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Ie9pQaQrn5ue"
      },
      "source": [
        "tflite_model_file = tflite_models_dir/\"mnist_model.tflite\"\n",
        "tflite_model_file.write_bytes(tflite_model)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7BONhYtYocQY"
      },
      "source": [
        "To quantize the model on export, set the `optimizations` flag to optimize for size:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "g8PUvLWDlmmz"
      },
      "source": [
        "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
        "tflite_quant_model = converter.convert()\n",
        "tflite_model_quant_file = tflite_models_dir/\"mnist_model_quant.tflite\"\n",
        "tflite_model_quant_file.write_bytes(tflite_quant_model)"
      ],
      "execution_count": null,
      "outputs": [
        {
         "name": "stdout",
         "output_type": "stream",
         "text": [
         "INFO:tensorflow:Assets written to: /tmp/tmppp6eqzty/assets\n",
         "INFO:tensorflow:Assets written to: /tmp/tmppp6eqzty/assets\n",
         "23888\n"
         ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PhMmUTl4sbkz"
      },
      "source": [
        "Note how the resulting file, is approximately `1/4` the size."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JExfcfLDscu4"
      },
      "source": [
        "!ls -lh {tflite_models_dir}"
      ],
      "execution_count": null,
      "outputs": [
        {
         "name": "stdout",
         "output_type": "stream",
         "text": [
          "total 108K\n",
          "-rw-r--r-- 1 root root 24K Apr 16 07:06 mnist_model_quant.tflite\n",
          "-rw-r--r-- 1 root root 83K Apr 16 07:06 mnist_model.tflite\n"
         ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "L8lQHMp_asCq"
      },
      "source": [
        "## Run the TFLite models\n",
        "\n",
        "Run the TensorFlow Lite model using the Python TensorFlow Lite\n",
        "Interpreter.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ap_jE7QRvhPf"
      },
      "source": [
        "### Load the model into an interpreter"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Jn16Rc23zTss"
      },
      "source": [
        "# TODO 2 - Here is your code.\n",
        "interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))\n",
        "interpreter.allocate_tensors()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "J8Pztk1mvNVL"
      },
      "source": [
        "interpreter_quant = tf.lite.Interpreter(model_path=str(tflite_model_quant_file))\n",
        "interpreter_quant.allocate_tensors()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2opUt_JTdyEu"
      },
      "source": [
        "### Test the model on one image"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AKslvo2kwWac"
      },
      "source": [
        "# Here, expanding the shape of an array\n",
        "test_image = np.expand_dims(test_images[0], axis=0).astype(np.float32)\n",
        "\n",
        "input_index = interpreter.get_input_details()[0][\"index\"]\n",
        "output_index = interpreter.get_output_details()[0][\"index\"]\n",
        "\n",
        "interpreter.set_tensor(input_index, test_image)\n",
        "interpreter.invoke()\n",
        "predictions = interpreter.get_tensor(output_index)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XZClM2vo3_bm"
      },
      "source": [
        "import matplotlib.pylab as plt\n",
        "\n",
        "# Displaying data as an image\n",
        "plt.imshow(test_images[0])\n",
        "template = \"True:{true}, predicted:{predict}\"\n",
        "_ = plt.title(template.format(true= str(test_labels[0]),\n",
        "                              predict=str(np.argmax(predictions[0]))))\n",
        "plt.grid(False)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LwN7uIdCd8Gw"
      },
      "source": [
        "### Evaluate the models"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "05aeAuWjvjPx"
      },
      "source": [
        "# A helper function to evaluate the TF Lite model using \"test\" dataset.\n",
        "def evaluate_model(interpreter):\n",
        "  input_index = interpreter.get_input_details()[0][\"index\"]\n",
        "  output_index = interpreter.get_output_details()[0][\"index\"]\n",
        "\n",
        "  # Run predictions on every image in the \"test\" dataset.\n",
        "  prediction_digits = []\n",
        "  for test_image in test_images:\n",
        "    # Pre-processing: add batch dimension and convert to float32 to match with\n",
        "    # the model's input data format.\n",
        "    # TODO 3 - Here is your code.\n",
        "    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)\n",
        "    interpreter.set_tensor(input_index, test_image)\n",
        "\n",
        "    # Run inference.\n",
        "    interpreter.invoke()\n",
        "\n",
        "    # Post-processing: remove batch dimension and find the digit with highest\n",
        "    # probability.\n",
        "    output = interpreter.tensor(output_index)\n",
        "    digit = np.argmax(output()[0])\n",
        "    prediction_digits.append(digit)\n",
        "\n",
        "  # Compare prediction results with ground truth labels to calculate accuracy.\n",
        "  accurate_count = 0\n",
        "  for index in range(len(prediction_digits)):\n",
        "    if prediction_digits[index] == test_labels[index]:\n",
        "      accurate_count += 1\n",
        "  accuracy = accurate_count * 1.0 / len(prediction_digits)\n",
        "\n",
        "  return accuracy"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DqXBnDfJ7qxL"
      },
      "source": [
        "print(evaluate_model(interpreter))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Km3cY9ry8ZlG"
      },
      "source": [
        "Repeat the evaluation on the dynamic range quantized model to obtain:\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-9cnwiPp6EGm"
      },
      "source": [
        "print(evaluate_model(interpreter_quant))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "L7lfxkor8pgv"
      },
      "source": [
        "In this example, the compressed model has no difference in the accuracy."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M0o1FtmWeKZm"
      },
      "source": [
        "## Optimizing an existing model\n",
        "\n",
        "Resnets with pre-activation layers (Resnet-v2) are widely used for vision applications.\n",
        "  Pre-trained frozen graph for resnet-v2-101 is available on\n",
        "  [Tensorflow Hub](https://tfhub.dev/google/imagenet/resnet_v2_101/classification/4).\n",
        "\n",
        "You can convert the frozen graph to a TensorFLow Lite flatbuffer with quantization by:\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jrXZxSJiJfYN"
      },
      "source": [
        "import tensorflow_hub as hub\n",
        "\n",
        "resnet_v2_101 = tf.keras.Sequential([\n",
        "  keras.layers.InputLayer(input_shape=(224, 224, 3)),\n",
        "  hub.KerasLayer(\"https://tfhub.dev/google/imagenet/resnet_v2_101/classification/4\")\n",
        "])\n",
        "\n",
        "converter = tf.lite.TFLiteConverter.from_keras_model(resnet_v2_101)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LwnV4KxwVEoG"
      },
      "source": [
        "# Convert to TF Lite without quantization\n",
        "resnet_tflite_file = tflite_models_dir/\"resnet_v2_101.tflite\"\n",
        "resnet_tflite_file.write_bytes(converter.convert())"
      ],
      "execution_count": null,
      "outputs": [
        {
         "name": "stdout",
         "output_type": "stream",
         "text": [
         "INFO:tensorflow:Assets written to: /tmp/tmpypcclov6/assets\n",
         "INFO:tensorflow:Assets written to: /tmp/tmpypcclov6/assets\n",
         "178509356\n"
         ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2qkZD0VoVExe"
      },
      "source": [
        "# Convert to TF Lite with quantization\n",
        "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
        "resnet_quantized_tflite_file = tflite_models_dir/\"resnet_v2_101_quantized.tflite\"\n",
        "resnet_quantized_tflite_file.write_bytes(converter.convert())"
      ],
      "execution_count": null,
      "outputs": [
        {
         "name": "stdout",
         "output_type": "stream",
         "text": [
         "INFO:tensorflow:Assets written to: /tmp/tmpl8ubb_mk/assets\n",
         "INFO:tensorflow:Assets written to: /tmp/tmpl8ubb_mk/assets\n",
         "46256896\n"
         ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vhOjeg1x9Knp"
      },
      "source": [
        "!ls -lh {tflite_models_dir}/*.tflite"
      ],
      "execution_count": null,
      "outputs": [
        {
         "name": "stdout",
         "output_type": "stream",
         "text": [
         "-rw-r--r-- 1 root root  24K Apr 16 07:06 /tmp/mnist_tflite_models/mnist_model_quant.tflite\n",
         "-rw-r--r-- 1 root root  83K Apr 16 07:06 /tmp/mnist_tflite_models/mnist_model.tflite\n",
         "-rw-r--r-- 1 root root  45M Apr 16 07:07 /tmp/mnist_tflite_models/resnet_v2_101_quantized.tflite\n",
         "-rw-r--r-- 1 root root 171M Apr 16 07:07 /tmp/mnist_tflite_models/resnet_v2_101.tflite\n"
         ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qqHLaqFMCjRZ"
      },
      "source": [
        "The model size reduces from 171 MB to 43 MB.\n",
        "The accuracy of this model on imagenet can be evaluated using the scripts provided for [TFLite accuracy measurement](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification).\n",
        "\n",
        "The optimized model top-1 accuracy is 76.8, the same as the floating point model."
      ]
    }
  ]
}
